
MarkAliasesPrepare applies bookend from inputs as well as outputs. #2815

Closed
wants to merge 15 commits

Conversation

@wujingyue (Collaborator) commented Aug 20, 2024:

This is a follow-up to #2639: that PR bookends from the outputs, and this one bookends from the inputs as well.
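
(For illustration only; this example is mine, not code from the PR. It sketches the kind of fusion the input-side bookend targets: the broadcast_in_dim is a meta op that consumes a fusion input, so the pass can conceptually split it into its own no-op segment, mirroring what #2639 does on the output side. The function name and the choice of ops are made up.)

import torch
from nvfuser import FusionDefinition, DataType

def fusion_with_input_side_meta_op(fd: FusionDefinition) -> None:
    # x: [16, 128, 1600] activation; w: [1600] weight.
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    S0 = fd.define_scalar(16, dtype=DataType.Int)
    S1 = fd.define_scalar(128, dtype=DataType.Int)
    S2 = fd.define_scalar(1600, dtype=DataType.Int)
    V0 = fd.define_vector([S0, S1, S2], dtype=DataType.Int)
    # Meta op adjacent to a fusion input: broadcast+expand the weight to the
    # activation's shape.
    T2 = fd.ops.broadcast_in_dim(T1, shape=V0, broadcast_dims=[2])
    # Compute op; with the input-side bookend, this is roughly what the
    # non-trivial segment would be left with.
    T3 = fd.ops.mul(T0, T2)
    fd.add_output(T3)

with FusionDefinition() as fd:
    fusion_with_input_side_meta_op(fd)

outputs = fd.execute([
    torch.randn(16, 128, 1600, device="cuda"),
    torch.randn(1600, device="cuda"),
])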

Fixes #2599.
Fixes #2577.

Benchmark results are mostly neutral. test_nanogpt_layer_norm got slower because of the added host latency. This is expected: compared with #2639, the latency of ops bookended from the inputs is more likely to land on the critical path. See benchmark details below.

$ nvidia-smi -L
GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-d9e8abeb-4f1a-5cd0-f825-a45f7ea57875)
$ python tools/benchmark_thunder.py --storage ~/workspace --sync main:main main:wjy/bookend
$ pytest-benchmark --storage ~/workspace compare 0027 0028 --group-by name

https://gist.github.com/wujingyue/fc11f9725bc510827d9c53061ca05898

0027 -- without the PR
0028 -- with the PR

Below are the nvprof traces of test_nanogpt_layer_norm[inference-thunder] without and with this PR.

Without the PR:

[image: nvprof trace without this PR]

With the PR:

[image: nvprof trace with this PR]

@wujingyue marked this pull request as draft August 20, 2024 22:11
@wujingyue (Collaborator, Author) commented:

cc @jjsjann123 in case you need it for experiments while I'm cleaning this up for review.

@wujingyue (Collaborator, Author) commented:

!build

@wujingyue requested a review from jjsjann123 August 21, 2024 06:27
@wujingyue marked this pull request as ready for review August 21, 2024 06:27
csrc/ops/alias.cpp: review comments (outdated, resolved)
csrc/ops/alias.h: review comments (outdated, resolved)
@jjsjann123 (Collaborator) left a comment:

Looks great! Thanks a lot for the improvement, and for deferring the refactor to keep the review easy.

if (std::all_of(first_user, last_user, [](const Use& use) {
return ir_utils::isSegmentSet(use.user);
})) {
if (std::all_of(users.begin(), users.end(), ir_utils::isSegmentSet)) {
@jjsjann123 (Collaborator) commented:

👏

use_of->uses().size()) {
std::vector<Expr*> users;
users.reserve(use_of->uses().size());
// `uses_to_segment` is sorted so `nullptr` if exists appears first.
@jjsjann123 (Collaborator) commented:

nitpick: uses_to_segment only exists at the caller scope.

// There are a few corner cases where we can avoid adding a
// `segment_set`. If a segment_set is to be added between `use_of` and all
// its users, ...
if (users.size() == use_of->uses().size()) {
@jjsjann123 (Collaborator) commented:

The reason we check for this is to use segment_set to avoid horizontal fusion of common consumers of a fusion input?

@wujingyue (Collaborator, Author) replied:

This is for cases like:

in --> [meta 0] -> out0
   |
   +-> [meta 1] -> out1

No segment_sets are needed at all.
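
(A hypothetical frontend version of that diagram, for illustration only; the ops and names are my own. Both users of the input are meta ops and each feeds an output directly, so inserting segment_sets would only create extra no-op segments without saving any compute.)

import torch
from nvfuser import FusionDefinition, DataType

def meta_only_fusion(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    out0 = fd.ops.permute(T0, dims=[1, 0, 2])  # [meta 0] -> out0
    out1 = fd.ops.permute(T0, dims=[2, 1, 0])  # [meta 1] -> out1
    fd.add_output(out0)
    fd.add_output(out1)

with FusionDefinition() as fd:
    meta_only_fusion(fd)

out0, out1 = fd.execute([torch.randn(2, 3, 4, device="cuda")])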

// The following emulates the bookend optimization. This is done in two
// steps: the first step bookends the outputs and the second step does the
// inputs. TODO(wujingyue): extract this into a function. I'm adding the new
// logic in place just to make review easier.
@jjsjann123 (Collaborator) commented:

🙇

// +--> reshape_1 ----+
//
// If we separate `reshape_0` and `reshape_1` from `mul`, the pointwise
// kernel would take double the input.
@jjsjann123 (Collaborator) commented:

Just want to point out that "double the input" means double the number of TensorViews passed to the kernel; in kernel execution, the redundant read might be served from cache.
But I agree this is still an overall good thing to have in the reshape example.

But this would stop us from bookending multiple slices of the same tensor... arguably, in that case we aren't really duplicating the memory buffer if the slices don't overlap.

Not pushing to change this in the PR. Can we add a comment on that?

Never mind, I see you special-case SliceOp below.

@wujingyue (Collaborator, Author) commented Aug 26, 2024:

This PR as is will slow down test_nanogpt_layer_norm[inference-thunder]. Here's an nvFuser-only reproducer:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4, T5 = fd.ops.var_mean(T3, dims=[2], correction=0, keepdim=False)
    S6 = fd.define_scalar(16, dtype=DataType.Int)
    S7 = fd.define_scalar(128, dtype=DataType.Int)
    S8 = fd.define_scalar(1, dtype=DataType.Int)
    V9 = fd.define_vector([S6, S7, S8], dtype=DataType.Int)
    T10 = fd.ops.broadcast_in_dim(T4, shape=V9, broadcast_dims=[0, 1])
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(128, dtype=DataType.Int)
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    V14 = fd.define_vector([S11, S12, S13], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T5, shape=V14, broadcast_dims=[0, 1])
    S16 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T17 = fd.ops.add(T10, S16)
    T18 = fd.ops.rsqrt(T17)
    S19 = fd.define_scalar(16, dtype=DataType.Int)
    S20 = fd.define_scalar(128, dtype=DataType.Int)
    S21 = fd.define_scalar(1600, dtype=DataType.Int)
    V22 = fd.define_vector([S19, S20, S21], dtype=DataType.Int)
    T23 = fd.ops.broadcast_in_dim(T15, shape=V22, broadcast_dims=[0, 1, 2])
    T24 = fd.ops.sub(T3, T23)
    S25 = fd.define_scalar(16, dtype=DataType.Int)
    S26 = fd.define_scalar(128, dtype=DataType.Int)
    S27 = fd.define_scalar(1600, dtype=DataType.Int)
    V28 = fd.define_vector([S25, S26, S27], dtype=DataType.Int)
    T29 = fd.ops.broadcast_in_dim(T18, shape=V28, broadcast_dims=[0, 1, 2])
    T30 = fd.ops.mul(T24, T29)
    S31 = fd.define_scalar(16, dtype=DataType.Int)
    S32 = fd.define_scalar(128, dtype=DataType.Int)
    S33 = fd.define_scalar(1600, dtype=DataType.Int)
    V34 = fd.define_vector([S31, S32, S33], dtype=DataType.Int)
    T35 = fd.ops.broadcast_in_dim(T1, shape=V34, broadcast_dims=[2])
    T36 = fd.ops.cast(T35, dtype=DataType.Float)
    T37 = fd.ops.mul(T30, T36)
    S38 = fd.define_scalar(16, dtype=DataType.Int)
    S39 = fd.define_scalar(128, dtype=DataType.Int)
    S40 = fd.define_scalar(1600, dtype=DataType.Int)
    V41 = fd.define_vector([S38, S39, S40], dtype=DataType.Int)
    T42 = fd.ops.broadcast_in_dim(T2, shape=V41, broadcast_dims=[2])
    T43 = fd.ops.cast(T42, dtype=DataType.Float)
    T44 = fd.ops.add(T37, T43)
    T45 = fd.ops.cast(T44, dtype=DataType.BFloat16)
    fd.add_output(T45)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn(3276800, dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
    torch.randn(1600, dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn(1600, dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
]
fd.execute(inputs)

Without this PR

$ nsys nvprof --print-gpu-trace python repro.py

 Start (ns)  Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MB/s)  SrcMemKd  DstMemKd                Device                Ctx  GreenCtx  Strm                                                  Name
 ----------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  ----------------------------------  ---  --------  ----  ----------------------------------------------------------------------------------------------------
 1633827617           6368     418  2048     1     1   128     1     1       40         0.000         0.001                                                     NVIDIA RTX 6000 Ada Generation (0)    1               7  <unnamed>::nvfuser_inner_persistent_f0_c1_r0_g0(<unnamed>::Tensor<<unnamed>::__bfloat, (int)3, (int…

CUDA kernel: https://gist.github.com/wujingyue/fb8ab7f75eb6cb6f3a33bde4aaf652fa

With this PR

 Start (ns)  Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MB/s)  SrcMemKd  DstMemKd                Device                Ctx  GreenCtx  Strm                                                  Name
 ----------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  ----------------------------------  ---  --------  ----  ----------------------------------------------------------------------------------------------------
 1708225834          11072     592  2048     1     1   416     1     1       28         0.000         0.002                                                     NVIDIA RTX 6000 Ada Generation (0)    1               7  <unnamed>::nvfuser_inner_persistent_f0_c1_r0_g2(<unnamed>::Tensor<<unnamed>::__bfloat, (int)3, (int…

CUDA kernel: https://gist.github.com/wujingyue/22ee624a6a387d29885ab480864c5c86

@wujingyue (Collaborator, Author) commented Aug 26, 2024:

I suspect the regression is caused by vectorization: the max allowed vectorize_factor becomes 1 with the PR.

$ _bn && NVFUSER_DUMP=scheduler_params python repro.py

Without the PR:

===== Persistent Kernel Properties ========
inner_most_dimension_numel: 1600
total_reduction_numel: 1600
total_iteration_numel: 2048
max_persistent_buffer_size: 3200
n_tensor_inputs: 3
max_input_dtype_size: 2
max allowed vectorize_factor: 8
project_persistent_buffers: 1

With the PR:

===== Persistent Kernel Properties ========
inner_most_dimension_numel: 1600
total_reduction_numel: 1600
total_iteration_numel: 2048
max_persistent_buffer_size: 3200
n_tensor_inputs: 3
max_input_dtype_size: 2
max allowed vectorize_factor: 1
project_persistent_buffers: 1

cc @liqiangxl

@wujingyue (Collaborator, Author) commented:

3ff26bc should fix the vectorization factor problem. I'll try to separate it out to a different PR. Thanks @jjsjann123 for the help offline!

@wujingyue (Collaborator, Author) commented:

!build

@liqiangxl (Collaborator) commented:

> 3ff26bc should fix the vectorization factor problem. I'll try to separate it out to a different PR. Thanks @jjsjann123 for the help offline!

May I get a note of your offline discussion? Just want to understand why this PR influences vectorization. Thanks!

@wujingyue (Collaborator, Author) commented:

> May I get a note of your offline discussion? Just want to understand why this PR influences vectorization. Thanks!

See #2854. Before that PR, nvFuser had trouble vectorizing an input of shape {1 ex 16, 1 ex 128, 1600}. This problem is only triggered by bookending the leading broadcast+expand; otherwise the input would be a plain {1600}.
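
(To make that shape concrete, a plain PyTorch illustration; this is my reading of the notation, not output from nvFuser: "{1 ex 16, 1 ex 128, 1600}" is a broadcast view expanded to 16 and 128 in the leading dimensions, i.e. an expanded tensor whose leading strides are 0. Once the leading broadcast+expand is bookended, that expanded view, rather than the plain 1-D {1600} weight, becomes the compute segment's input, which is what tripped the vectorization analysis before #2854.)

import torch

w = torch.randn(1600)
w_expanded = w.expand(16, 128, 1600)  # broadcast view; no data is copied
print(w_expanded.shape)    # torch.Size([16, 128, 1600])
print(w_expanded.stride()) # (0, 0, 1) -- the expanded dims have stride 0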

@wujingyue (Collaborator, Author) commented:

@kevinstephano, @tfogal, and @jjsjann123: I'd like to hear your opinions on this. As said in the description, this PR is mostly performance neutral but does slow down test_nanogpt_layer_norm microbenchmarks due to the added host latency. Options I can think of are:

  1. Bite the bullet and merge the PR.
  2. Tune down nvFuser's bookending optimization [1] to bookend only certain op types (e.g. slice and squeeze) that are useful for @jjsjann123's RoPE work and don't affect layernorm.
  3. Close this PR and revisit when the host latency is improved.
  4. Close this PR and wait for this effort to handle slices without having to bookend.

Footnotes

  1. I'm referring to MarkAliasesPrepare, the pre-segmenter pass this PR changes, not Thunder's bookending, which has been disabled by default.

@tfogal (Collaborator) commented Aug 29, 2024:

> @kevinstephano, @tfogal, and @jjsjann123: I'd like to hear your opinions on this. As said in the description, this PR is mostly performance neutral but does slow down test_nanogpt_layer_norm microbenchmarks due to the added host latency.

This is incredibly timely!

I just added a PR for some NeVA benchmarks: Lightning-AI/lightning-thunder#1064. How does this impact g1, g6, and g13? The latter two appear to be RoPE or part of RoPE, and g1 is a monster but notably includes layer_norm.

My feeling is that if this improves those 3 graphs, I'd argue it's a clear "merge now". Those are real-world cases on a model we care about in the near term. test_nanogpt_layer_norm is just a microbenchmark.

@wujingyue (Collaborator, Author) commented:

@tfogal g1, g6 and g13 are all performance neutral. g1 seems to be a huge trace where only a small portion runs on nvFuser. g6 and g13 have torch.cat, so they are picked up by the cat executor. I also tried disabling the cat executor and found performance still neutral.

That being said, I found more interesting results when benchmarking test_litgpt_qkv_split_rope with the cat executor off (it's on by default):

  1. Inference and forward got a speedup, the largest being 2x for
    -------------------------------------------------------------------------------------- benchmark 'test_litgpt_qkv_split_rope[Llama-2-13b-hf-inference-bs2-thunder]': 2 tests --------------------------------------------------------------------------------------
    Name (time in us)                                                                          Min                   Max                  Mean             StdDev                Median               IQR            Outliers         OPS            Rounds  Iterations
    -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    test_litgpt_qkv_split_rope[Llama-2-13b-hf-inference-bs2-thunder] (0034_thunder)       536.9694 (1.0)        687.6239 (1.0)        546.7848 (1.0)      24.1743 (1.0)        541.3722 (1.0)      2.2948 (1.0)         45;49  1,828.8729 (1.0)         934           2
    test_litgpt_qkv_split_rope[Llama-2-13b-hf-inference-bs2-thunder] (0033_thunder)     1,165.2717 (2.17)     1,482.8211 (2.16)     1,188.9340 (2.17)     52.9914 (2.19)     1,176.5324 (2.17)     4.9081 (2.14)        45;55    841.0896 (0.46)        858           1
    -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    
  2. Backward got a slowdown, the largest being 1.4x for
    ---------------------------------------------------------------------------- benchmark 'test_litgpt_qkv_split_rope[Llama-3-70B-backward-bs2-thunder]': 2 tests -----------------------------------------------------------------------------
    Name (time in ms)                                                                  Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
    --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    test_litgpt_qkv_split_rope[Llama-3-70B-backward-bs2-thunder] (0033_thunder)     3.9539 (1.0)      4.2474 (1.0)      3.9899 (1.0)      0.0647 (1.0)      3.9735 (1.0)      0.0090 (1.0)         16;18  250.6324 (1.0)         253           1
    test_litgpt_qkv_split_rope[Llama-3-70B-backward-bs2-thunder] (0034_thunder)     5.6405 (1.43)     6.0317 (1.42)     5.6803 (1.42)     0.0673 (1.04)     5.6634 (1.43)     0.0120 (1.32)        11;12  176.0460 (0.70)        178           1
    --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    

I've yet to figure out the reason behind the slowdown in backprop. The complete fusion looks much more complicated:
[image: the complete backward fusion graph]

I noticed some inefficiencies, e.g. tensor 18, an fp32 tensor, gets segmented with this PR. However, none of them seems significant enough to explain the 1.4x.

@wujingyue (Collaborator, Author) commented:

Anyhow, I don't think this PR is a clear win, and I'd probably wait until we understand RoPE backprop better.

@wujingyue marked this pull request as draft August 29, 2024 21:41
@tfogal (Collaborator) commented Aug 29, 2024:

> g1, g6 and g13 are all performance neutral

Thanks for your analysis! I appreciate you thinking to disable the cat executor too; I should have thought to mention that earlier.

It would be great, of course, to understand the 1.4x slowdown and endeavor to get this either neutral or beneficial in all cases. But I think that's also a high bar, and if it's not impacting the real-world g* cases and is a wash on the existing microbenchmarks, then 🤷.

> I'd probably wait until we understand RoPE backprop better.

Sounds good! But if you start getting merge difficulties I don't think it would be dire to merge it, at least if you think this is a helpful intermediate step towards:

> Tune down nvFuser's bookending optimization

or

> wait for this effort to handle slices without having to bookend

@jjsjann123 (Collaborator) commented:

> Anyhow, I don't think this PR is a clear win, and I'd probably wait until we understand RoPE backprop better.

😢

> Backward got a slowdown, the largest being 1.4x for

QQ: what's the baseline for the comparison? I can't see it clearly in the graph, but did the alias pass in this PR remove the permutation on the allocation domain of the fusion outputs?

@wujingyue (Collaborator, Author) commented:

The comparison has always been between runs with and without this PR. However, in the latest comparison I turned off the cat executor for both the baseline and the test.

@wujingyue added the enhancement (New feature or request) label Sep 7, 2024
@wujingyue closed this Sep 25, 2024

Successfully merging this pull request may close these issues:

Suboptimal segmentation for RoPE
Alias analysis missing out opportunities on aliasing within fusion segments